from __future__ import annotations
import typing as t

import openllm_core

if t.TYPE_CHECKING:
  import transformers

# Building blocks of the Dolly prompt: section markers plus the fixed preamble
# that opens every prompt.
INSTRUCTION_KEY = '### Instruction:'
RESPONSE_KEY = '### Response:'
END_KEY = '### End'
INTRO_BLURB = (
  'Below is an instruction that describes a task. '
  'Write a response that appropriately completes the request.'
)


def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
  """Return the id of the single token that ``key`` encodes to.

  Args:
    tokenizer: Tokenizer whose ``encode`` method maps ``key`` to a sequence of token ids.
    key: Text expected to correspond to exactly one token.

  Returns:
    The sole token id produced by ``tokenizer.encode(key)``.

  Raises:
    ValueError: If ``key`` does not encode to exactly one token, i.e. the
      encoding is empty or contains more than one id.
  """
  token_ids = tokenizer.encode(key)
  # Reject the empty encoding too; otherwise the indexing below would fail
  # with an uninformative IndexError instead of naming the offending key.
  if len(token_ids) != 1:
    raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
  return token_ids[0]


class DollyV2Config(openllm_core.LLMConfig):
  """Configuration for Databricks' Dolly v2 instruction-following models.

  Dolly is an instruction-following large language model trained on the Databricks
  machine learning platform and licensed for commercial use.

  It is based on pythia-12b and fine-tuned on databricks-dolly-15k: roughly 15k
  instruction/response records written by Databricks employees across the capability
  domains of the InstructGPT paper (brainstorming, classification, closed QA,
  generation, information extraction, open QA and summarization).

  dolly-v2-12b is not a state-of-the-art model, yet it shows surprisingly high quality
  instruction-following behavior that is not characteristic of its foundation model.

  Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
  """

  # Model-level settings read by the LLMConfig base class.
  __config__ = {
    'timeout': 3600000,
    'url': 'https://github.com/databrickslabs/dolly',
    'architecture': 'GPTNeoXForCausalLM',
    'default_id': 'databricks/dolly-v2-3b',
    'model_ids': [
      'databricks/dolly-v2-3b',
      'databricks/dolly-v2-7b',
      'databricks/dolly-v2-12b',
    ],
  }

  class GenerationConfig:
    # Default generation parameters for Dolly v2.
    temperature: float = 0.9
    top_p: float = 0.92
    top_k: int = 5
    max_new_tokens: int = 256
    eos_token_id: int = 50277  # NOTE: from get_special_token_id(self.tokenizer, END_KEY)

  @property
  def template(self) -> str:
    """Prompt template with a literal ``{instruction}`` placeholder left for later formatting."""
    # Doubled braces keep '{instruction}' verbatim in the result, so callers
    # can substitute the user's instruction afterwards.
    return f'{INTRO_BLURB}\n{INSTRUCTION_KEY}\n{{instruction}}\n{RESPONSE_KEY}\n'
